import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp

from train import *

import argparse
from utils.config import Config
from tools.mask_convert import mask_save

def get_args(argv=None):
    """Parse command-line options for the change-detection test script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as argparse normally does.

    Returns:
        argparse.Namespace with:
            config: path to the config file (default ``configs/STNet.py``).
            ckpt: checkpoint path, or ``None`` if not given.
            output_dir: output directory, or ``None`` if not given.
    """
    # `description` must be passed by keyword: the first positional parameter
    # of ArgumentParser is `prog`, so passing the text positionally renamed the
    # program in usage/help output instead of setting its description.
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/STNet.py",
                        help="path to the config file")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="checkpoint to evaluate (defaults to test_ckpt_path from the config)")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="directory for results (defaults to the checkpoint's directory)")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Evaluate a trained checkpoint on the test split:
    #   1. print per-class and overall metrics and append them to
    #      <base_dir>/test_res.txt,
    #   2. write the predicted masks to <base_dir>/mask_rgb using a process pool.
    # NOTE(review): `myTrain`, `build_dataloader` and `torchmetrics` are not
    # imported in this file; presumably they come from `from train import *`.
    import numpy as np  # made explicit rather than relying on the star import

    args = get_args()

    # The config must be loaded first: the default checkpoint path is read from it.
    cfg = Config.fromfile(args.config)

    ckpt = args.ckpt
    if ckpt is None:
        ckpt = getattr(cfg, "test_ckpt_path", None)
    if ckpt is None:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise SystemExit("No checkpoint given: pass --ckpt or set test_ckpt_path in the config.")

    # Results go next to the checkpoint unless --output_dir overrides it.
    base_dir = args.output_dir or os.path.dirname(ckpt)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Create the output tree up front. test_res.txt is written into base_dir
    # before any masks are saved, so base_dir must already exist.
    os.makedirs(masks_output_dir, exist_ok=True)

    model = myTrain.load_from_checkpoint(ckpt, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')
    all_metrics = (test_oa, test_iou, test_prec, test_f1, test_recall)

    # One (ground_truth, prediction, output_dir, name) tuple per sample, for mask_save.
    results = []
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        for batch in tqdm(test_loader):
            # Batch layout as used here: (image_1, image_2, label_mask, image_ids).
            img_1, img_2 = batch[0].cuda(), batch[1].cuda()
            mask, img_id = batch[2].cuda(), batch[3]

            raw_predictions = model(img_1, img_2)
            pred = raw_predictions.argmax(dim=1)

            for metric in all_metrics:
                metric(pred, mask)

            # A single device->host copy per batch rather than one per sample.
            mask_np = mask.cpu().numpy()
            pred_np = pred.cpu().numpy()
            for i in range(raw_predictions.shape[0]):
                results.append((mask_np[i], pred_np[i], masks_output_dir, str(img_id[i])))

    # Precision / Recall / F1 / IoU, each expected to hold one value per class
    # (the per-class rows below index into them).
    per_class = [m.compute().cpu().numpy() for m in (test_prec, test_recall, test_f1, test_iou)]
    overall = [test_oa.compute().cpu().numpy()] + [np.mean(values) for values in per_class]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    # Number of classes taken from the metric output instead of being hard-coded.
    num_classes = len(per_class[0])
    for cls in range(num_classes):
        result_table.add_row([cls, '--'] + [np.round(values[cls], 4) for values in per_class])
    result_table.add_row(['total'] + [np.round(v, 4) for v in overall])

    print(result_table)

    file_name = os.path.join(base_dir, "test_res.txt")
    # Format the timestamp on its own so a '%' in cfg.net is not treated as a
    # strftime directive.
    current_time = '{} {}'.format(time.strftime('%Y_%m_%d %H:%M:%S', time.localtime()), cfg.net)
    with open(file_name, "a", encoding="utf-8") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    print(masks_output_dir)

    t0 = time.time()
    # The context manager shuts the worker processes down once the (blocking) map finishes.
    with mpp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(mask_save, results)
    t1 = time.time()
    img_write_time = t1 - t0
    print('images writing spends: {} s'.format(img_write_time))